{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malay-dataset/summarization/cnn-news/translated-cnn-test.json\n",
    "# !wget https://f000.backblazeb2.com/file/malay-dataset/summarization/dailymail/translated-dailymail-test.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Default to GPU 3, but keep any CUDA_VISIBLE_DEVICES the caller already exported\n",
    "# so the notebook can be pointed at another device without editing this cell.\n",
    "os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/malaya/pretrained-model/pegasus/tokenization.py:135: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import tokenization\n",
    "\n",
    "# Tokenizer over the local `pegasus.wordpiece` vocabulary, with lower-casing switched off.\n",
    "tokenizer = tokenization.FullTokenizer(vocab_file='pegasus.wordpiece', do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"\n",
    "    Load a frozen GraphDef protobuf and import it into a new tf.Graph.\n",
    "\n",
    "    frozen_graph_filename: path of the serialized graph file.\n",
    "    Returns the tf.Graph; imported nodes sit under the 'import/' name scope.\n",
    "    \"\"\"\n",
    "    # tf.io.gfile.GFile and tf.compat.v1.GraphDef are the non-deprecated\n",
    "    # spellings of tf.gfile.GFile and tf.GraphDef.\n",
    "    with tf.io.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
    "        graph_def = tf.compat.v1.GraphDef()\n",
    "        graph_def.ParseFromString(f.read())\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        # Explicit scope name: tensors are looked up later as 'import/<name>:0'.\n",
    "        tf.import_graph_def(graph_def, name='import')\n",
    "    return graph\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = load_graph('output/frozen_model.pb')\n",
    "# Resolve the input placeholder, the two decoding knobs and the output tensor in one pass.\n",
    "x, top_p, temperature, logits = map(\n",
    "    g.get_tensor_by_name,\n",
    "    (\n",
    "        'import/Placeholder:0',\n",
    "        'import/top_p:0',\n",
    "        'import/temperature:0',\n",
    "        'import/logits:0',\n",
    "    ),\n",
    ")\n",
    "test_sess = tf.InteractiveSession(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# JSON is UTF-8 by specification; pass the encoding explicitly so that loading\n",
    "# the Malay text does not depend on the machine's locale.\n",
    "data = []\n",
    "for path in ('translated-cnn-test.json', 'translated-dailymail-test.json'):\n",
    "    with open(path, encoding='utf-8') as fopen:\n",
    "        data.extend(json.load(fopen))\n",
    "\n",
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each record stores its article / abstract as a list of strings; join them into one text.\n",
    "X = [' '.join(row['ms_article']) for row in data]\n",
    "Y = [' '.join(row['ms_abstract']) for row in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from unidecode import unidecode\n",
    "from malaya.text.rules import normalized_chars\n",
    "\n",
    "def filter_news(string):\n",
    "    \"\"\"\n",
    "    Return True when the lower-cased text contains any of the marker substrings below.\n",
    "    \"\"\"\n",
    "    # The bare 'javascript' marker already matches whatever the two longer ones match.\n",
    "    markers = (\n",
    "        'javascript is disabled',\n",
    "        'requires javascript',\n",
    "        'javascript',\n",
    "        'président',\n",
    "    )\n",
    "    lowered = string.lower()\n",
    "    return any(marker in lowered for marker in markers)\n",
    "\n",
    "def make_cleaning(s, c_dict):\n",
    "    \"\"\"\n",
    "    Run `s` through str.translate with the `c_dict` table and return the result.\n",
    "    \"\"\"\n",
    "    return s.translate(c_dict)\n",
    "\n",
    "def transformer_textcleaning(string):\n",
    "    \"\"\"\n",
    "    use by any transformer model before tokenization\n",
    "\n",
    "    Steps, in order: pass the text through unidecode, translate every word with\n",
    "    `normalized_chars`, rewrite '(dot)' to '.', drop the attributes of the first\n",
    "    `<a ...>` tag when they contain 'href', blank out URLs, collapse whitespace\n",
    "    and remove words that start with '@'. Returns the cleaned string.\n",
    "    \"\"\"\n",
    "    string = unidecode(string)\n",
    "    string = ' '.join(\n",
    "        make_cleaning(w, normalized_chars) for w in string.split()\n",
    "    )\n",
    "    # Raw string: '\\(' in a plain literal is an invalid escape sequence.\n",
    "    string = re.sub(r'\\(dot\\)', '.', string)\n",
    "    # Look for anchor tags once, with a single pattern, and strip the attributes\n",
    "    # as literal text: hrefs are full of regex metacharacters ('?', '+', '(')\n",
    "    # and must not be handed to re.sub as a pattern.\n",
    "    anchors = re.findall(r'\\<a (.*?)\\>', string)\n",
    "    if anchors and 'href' in anchors[0]:\n",
    "        string = string.replace(' ' + anchors[0], '')\n",
    "    string = re.sub(\n",
    "        r'\\w+:\\/{2}[\\d\\w-]+(\\.[\\d\\w-]+)*(?:(?:\\/[^\\s/]*))*', ' ', string\n",
    "    )\n",
    "    string = string.replace('\\n', ' ')\n",
    "    words = re.sub(r'[ ]+', ' ', string).strip().split()\n",
    "    return ' '.join(w for w in words if w[0] != '@')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper used when batching token ids.\n",
    "keras_sequence = tf.keras.preprocessing.sequence\n",
    "pad_sequences = keras_sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1200/1200 [1:14:36<00:00,  3.73s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "batch_size = 10\n",
    "max_input_len = 512  # inputs longer than this many token ids are truncated\n",
    "results = []\n",
    "for start in tqdm(range(0, len(X), batch_size)):\n",
    "    batch_ids = [\n",
    "        tokenizer.convert_tokens_to_ids(\n",
    "            tokenizer.tokenize(transformer_textcleaning(article))\n",
    "        )[:max_input_len]\n",
    "        for article in X[start: start + batch_size]\n",
    "    ]\n",
    "    padded = pad_sequences(batch_ids, padding='post')\n",
    "    # Named `decoded` rather than `g` so the graph handle loaded earlier is not overwritten.\n",
    "    # NOTE(review): top_p = 0 / temperature = 0 presumably selects greedy decoding - confirm.\n",
    "    decoded = test_sess.run(logits, feed_dict = {x: padded, top_p: 0.0, temperature: 0.0})\n",
    "    results.extend(decoded.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import rouge\n",
    "from tensorflow.keras.preprocessing import sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _reference_aligned_pairs(predicted, batch_y):\n",
    "    \"\"\"\n",
    "    Yield one (reference_ids, predicted_ids) pair per row.\n",
    "\n",
    "    predicted: list of token-id lists produced by the model.\n",
    "    batch_y: list of raw reference strings; they are cleaned and tokenized here.\n",
    "    Both sides are zero-padded to a common length, then cut to the number of\n",
    "    non-zero ids in the reference row.\n",
    "    \"\"\"\n",
    "    batch_y = [\n",
    "        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(transformer_textcleaning(row)))\n",
    "        for row in batch_y\n",
    "    ]\n",
    "    maxlen = max(max(len(row) for row in batch_y), max(len(row) for row in predicted))\n",
    "    batch_y = sequence.pad_sequences(batch_y, padding = 'post', maxlen = maxlen)\n",
    "    predicted = sequence.pad_sequences(predicted, padding = 'post', maxlen = maxlen)\n",
    "\n",
    "    # Slicing the first `n` positions selects the same elements as the former\n",
    "    # boolean mask, without needing the `np.bool` alias removed in NumPy 1.24.\n",
    "    # NOTE(review): the prediction is cut (or zero-padded) to the reference's\n",
    "    # token count, exactly as before - confirm this is the intended protocol.\n",
    "    lengths = np.count_nonzero(batch_y, axis = 1)\n",
    "    for ref_row, pred_row, n in zip(batch_y, predicted, lengths):\n",
    "        yield ref_row[:n], pred_row[:n]\n",
    "\n",
    "def calculate_rouges(predicted, batch_y, n_size = 2):\n",
    "    \"\"\"\n",
    "    Mean of tensor2tensor's `rouge_n` (n-gram size `n_size`) over all rows.\n",
    "\n",
    "    predicted: list of predicted token-id lists.\n",
    "    batch_y: list of reference strings, same order as `predicted`.\n",
    "    \"\"\"\n",
    "    scores = [\n",
    "        rouge.rouge_n([ref], [pred], n = n_size)\n",
    "        for ref, pred in _reference_aligned_pairs(predicted, batch_y)\n",
    "    ]\n",
    "    return np.mean(scores)\n",
    "\n",
    "def calculate_rouge_l(predicted, batch_y):\n",
    "    \"\"\"\n",
    "    Mean of tensor2tensor's `rouge_l_sentence_level` over all rows.\n",
    "\n",
    "    predicted: list of predicted token-id lists.\n",
    "    batch_y: list of reference strings, same order as `predicted`.\n",
    "    \"\"\"\n",
    "    scores = [\n",
    "        rouge.rouge_l_sentence_level([ref], [pred])\n",
    "        for ref, pred in _reference_aligned_pairs(predicted, batch_y)\n",
    "    ]\n",
    "    return np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.29012334"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ROUGE with unigrams (n_size = 1) over every decoded summary.\n",
    "calculate_rouges(predicted = results, batch_y = Y[:len(results)], n_size = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.11878814"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ROUGE with bigrams (n_size = 2) over every decoded summary.\n",
    "calculate_rouges(predicted = results, batch_y = Y[:len(results)], n_size = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.19232224"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sentence-level ROUGE-L over every decoded summary.\n",
    "calculate_rouge_l(predicted = results, batch_y = Y[:len(results)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
